{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as  plt\n",
    "import sys\n",
    "sys.path.append('..')\n",
    "import d2lzh_pytorch as d21"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Multilayer perceptron for 10-way classification: 784 inputs -> 256 hidden units -> 10 outputs.\n",
    "# NOTE(review): d21.FlattenLayer is defined outside this notebook; presumably it reshapes\n",
    "# each image batch to (batch, 784) before the first Linear layer -- confirm.\n",
    "num_inputs,num_outputs,num_hiddens=784,10,256\n",
    "net=nn.Sequential(\n",
    "    d21.FlattenLayer(),\n",
    "    nn.Linear(num_inputs,num_hiddens),\n",
    "    # Nonlinearity between the Linear layers: without it the two layers compose\n",
    "    # into a single affine map, so the hidden layer adds no modelling capacity.\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(num_hiddens,num_outputs))\n",
    "# Draw every weight and bias from N(0, 0.01^2) so the initial outputs stay close to zero.\n",
    "# The saved run of the activation-free, default-initialised network trained with lr=0.5\n",
    "# reported loss nan; small initial parameters are intended to keep that SGD run stable.\n",
    "for param in net.parameters():\n",
    "    nn.init.normal_(param,mean=0,std=0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "epoch 1, loss nan, train acc 0.108, test acc 0.100\nepoch 2, loss nan, train acc 0.100, test acc 0.100\nepoch 3, loss nan, train acc 0.100, test acc 0.100\nepoch 4, loss nan, train acc 0.100, test acc 0.100\nepoch 5, loss nan, train acc 0.100, test acc 0.100\n"
    }
   ],
   "source": [
    "# Train `net` with minibatch SGD on the Fashion-MNIST iterators; the training helper\n",
    "# prints loss, train accuracy and test accuracy once per epoch.\n",
    "batch_size=256\n",
    "train_iter,test_iter=d21.load_data_fashion_mnist(batch_size)\n",
    "loss=torch.nn.CrossEntropyLoss()\n",
    "optimizer=torch.optim.SGD(net.parameters(),lr=0.5)\n",
    "num_epochs=5\n",
    "# Pass num_epochs (defined just above). The previous spelling `num_rpochs` is not defined\n",
    "# anywhere in this notebook, so the call only worked with leftover kernel state.\n",
    "# NOTE(review): the two None arguments are presumably the manual params/lr slots, unused\n",
    "# because an optimizer is supplied -- confirm against d21.train_ch3.\n",
    "d21.train_ch3(net,train_iter,test_iter,loss,num_epochs,batch_size,None,None,optimizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python37464bitbasecondab04fb1e447c742be8678bc28fb07e8e2",
   "display_name": "Python 3.7.4 64-bit ('base': conda)"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}